In [1]:
import os
import pywt
#from wavelets.wave_python.waveletFunctions import *
import itertools
import numpy as np
import pandas as pd
from scipy.fftpack import fft
from collections import Counter
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
In [2]:
# Load the sea-surface-temperature series and build its time axis.
dataset = "https://raw.githubusercontent.com/taspinar/siml/master/datasets/sst_nino3.dat.txt"
# pd.read_table treats the first line as a header unless told otherwise, which
# would turn the first sample into a column name and silently drop it.
# NOTE(review): this assumes the file is a headerless single column of values -
# confirm against the raw file.
df_nino = pd.read_table(dataset, header=None)
N = df_nino.shape[0]  # number of samples
t0=1871               # time of the first sample
dt=0.25               # sampling interval, in the same unit as t0
time = np.arange(0, N) * dt + t0
signal = df_nino.values.squeeze()  # collapse the single-column frame to 1-D
In [3]:
# First, let's load the El Niño dataset and plot it together with its time average.
def get_ave_values(xvalues, yvalues, n = 5):
    """Average a signal over consecutive, non-overlapping bins of ``n`` samples.

    Trailing samples that do not fill a complete bin are discarded.

    Parameters
    ----------
    xvalues : array-like
        Sample positions (e.g. time stamps); its length decides how many
        complete bins exist.
    yvalues : array-like
        Sample values; must provide at least as many samples as the
        complete bins of ``xvalues`` require.
    n : int, optional
        Number of samples per bin (default 5).

    Returns
    -------
    x_ave : numpy.ndarray
        The first x value of every bin.
    y_ave : numpy.ndarray
        The NaN-ignoring mean of the y values in every bin.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1 or ``yvalues`` is too short.
    """
    if n < 1:
        raise ValueError("n must be a positive integer, got {!r}".format(n))

    xarr = np.ravel(np.asarray(xvalues))
    yarr = np.ravel(np.asarray(yvalues))

    n_bins = xarr.size // n
    n_used = n_bins * n  # samples that fall into complete bins
    if yarr.size < n_used:
        raise ValueError(
            "yvalues has {} samples but {} are needed".format(yarr.size, n_used))

    # Slice + reshape instead of in-place ndarray.resize: resize silently
    # zero-pads short input and may refuse to run on arrays that are
    # referenced elsewhere.
    x_ave = xarr[:n_used:n].copy()  # copy so the result never aliases the input
    y_ave = np.nanmean(yarr[:n_used].reshape(n_bins, n), axis=1)
    return x_ave, y_ave
def plot_signal_plus_average(ax, time, signal, average_over = 5):
    """Plot a signal together with its binned average on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    time : sequence
        Sample positions of ``signal``; also used for the x-axis limits.
    signal : sequence
        Signal values.
    average_over : int, optional
        Bin size passed on to ``get_ave_values`` (default 5).
    """
    time_ave, signal_ave = get_ave_values(time, signal, average_over)
    ax.plot(time, signal, label='signal')
    # Report the bin size that was actually used rather than a fixed number.
    ax.plot(time_ave, signal_ave,
            label='time average (n={})'.format(average_over))
    ax.set_xlim([time[0], time[-1]])
    ax.set_ylabel('Amplitude', fontsize=16)
    ax.set_title('Signal + Time Average', fontsize=16)
    ax.legend(loc='upper right')
# Draw the signal and its 3-sample binned average on a wide, short figure.
fig, ax = plt.subplots(figsize=(12,3))
plot_signal_plus_average(ax, time, signal, average_over = 3)
plt.show()
In [4]:
def get_fft_values(y_values, T, N, f_s):
    """Return the one-sided amplitude spectrum of a zero-padded signal.

    The signal is zero-padded to ``N2``, the next power of two above ``N``,
    before it is transformed, and only the first ``N2 // 2`` bins
    (zero frequency up to, but excluding, the Nyquist frequency) are returned.

    Parameters
    ----------
    y_values : array-like
        Signal samples.
    T : float
        Sampling interval.
    N : int
        Number of samples in ``y_values``.
    f_s : float
        Sampling frequency; not used, kept so existing callers keep working.

    Returns
    -------
    f_values : numpy.ndarray
        Frequencies of the returned bins, ``k / (N2 * T)``.
    fft_values : numpy.ndarray
        ``2 / N`` times the magnitude of the transform at those bins.
    """
    N2 = 2 ** (int(np.log2(N)) + 1) # round up to next highest power of 2
    n_half = N2 // 2
    # Bin k of a length-N2 transform sits at k / (N2 * T); a linspace that
    # includes the Nyquist endpoint would stretch the axis by one bin.
    f_values = np.arange(n_half) / (N2 * T)
    # Actually pad to N2 so that the first N2 // 2 bins are exactly the
    # non-negative-frequency half of the spectrum.
    fft_values_ = fft(y_values, n=N2)
    # Zero padding adds no energy, so normalise by the real sample count.
    fft_values = 2.0 / N * np.abs(fft_values_[:n_half])
    return f_values, fft_values
def plot_fft_plus_power(ax, time, signal, plot_direction='horizontal', yticks=None, ylim=None):
    """Plot the Fourier amplitude spectrum and a power spectrum of ``signal``.

    The power curve is ``variance(signal) * |amplitude| ** 2``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    time : sequence
        Sample positions; ``time[1] - time[0]`` is taken as the sampling
        interval.
    signal : sequence
        Signal values.
    plot_direction : {'horizontal', 'vertical'}, optional
        'horizontal' plots the spectra against frequency. 'vertical' plots
        them against ``log2(1 / frequency)`` on the y-axis. Any other value
        plots nothing.
    yticks : sequence, optional
        Tick values for the y-axis in 'vertical' mode; they are positioned at
        ``log2(yticks)``. Left untouched when None.
    ylim : sequence, optional
        In 'vertical' mode ``ylim[0]`` becomes the lower y-limit and the upper
        limit is fixed at -1. Left untouched when None.
    """
    dt = time[1] - time[0]
    N = len(signal)
    fs = 1/dt
    variance = np.var(signal)
    f_values, fft_values = get_fft_values(signal, dt, N, fs)
    fft_power = variance * abs(fft_values) ** 2
    if plot_direction == 'horizontal':
        ax.plot(f_values, fft_values, 'r-', label='Fourier Transform')
        ax.plot(f_values, fft_power, 'k--', linewidth=1, label='FFT Power Spectrum')
    elif plot_direction == 'vertical':
        # The zero-frequency bin has no finite period; leave it out instead of
        # dividing by zero.
        has_period = f_values > 0
        scales_log = np.log2(1. / f_values[has_period])
        ax.plot(fft_values[has_period], scales_log, 'r-', label='Fourier Transform')
        ax.plot(fft_power[has_period], scales_log, 'k--', linewidth=1, label='FFT Power Spectrum')
        # yticks / ylim default to None, so only apply them when provided.
        if yticks is not None:
            ax.set_yticks(np.log2(yticks))
            ax.set_yticklabels(yticks)
        ax.invert_yaxis()
        if ylim is not None:
            ax.set_ylim(ylim[0], -1)
    ax.legend()
# Plot the Fourier amplitude and power spectra of the signal against frequency.
fig, ax = plt.subplots(figsize=(12,3))
# The time axis is counted in years, so frequency is in cycles per year
# (not Hz, which would be cycles per second).
ax.set_xlabel('Frequency [1 / year]', fontsize=18)
ax.set_ylabel('Amplitude', fontsize=18)
plot_fft_plus_power(ax, time, signal)
plt.show()
In [5]:
def plot_wavelet(ax, time, signal, scales, waveletname = 'cmor',
                 cmap = plt.cm.seismic, title = '', ylabel = '', xlabel = ''):
    """Draw a filled-contour plot of the wavelet power of ``signal`` on ``ax``.

    The continuous wavelet transform is computed with ``pywt.cwt``; the
    squared magnitude of its coefficients is contoured on a log2 scale
    against ``time`` (x-axis) and ``log2(period)`` (y-axis).

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    time : sequence
        Sample positions; ``time[1] - time[0]`` is passed to ``pywt.cwt`` as
        the sampling period.
    signal : sequence
        Signal values.
    scales : sequence
        Scales handed to ``pywt.cwt``.
    waveletname : str, optional
        Wavelet name handed to ``pywt.cwt``.
    cmap : matplotlib colormap, optional
        Colormap for the contours.
    title, ylabel, xlabel : str, optional
        Axes title and axis labels.

    Returns
    -------
    yticks : numpy.ndarray
        The powers of two used as y tick labels.
    ylim : tuple
        The y-limits as they were right after inverting the axis, i.e. before
        the upper limit is replaced by -1.
    """
    sampling_period = time[1] - time[0]
    coefficients, frequencies = pywt.cwt(signal, scales, waveletname, sampling_period)

    period = 1. / frequencies
    power = np.abs(coefficients) ** 2

    # Contour boundaries are powers of two, expressed in log2 units.
    contourlevels = np.log2([0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8])
    ax.contourf(time, np.log2(period), np.log2(power), contourlevels,
                extend='both', cmap=cmap)

    ax.set_title(title, fontsize=20)
    ax.set_ylabel(ylabel, fontsize=18)
    ax.set_xlabel(xlabel, fontsize=18)

    # One tick per power of two inside the range of periods.
    first_exponent = np.ceil(np.log2(period.min()))
    stop_exponent = np.ceil(np.log2(period.max()))
    yticks = 2 ** np.arange(first_exponent, stop_exponent)
    ax.set_yticks(np.log2(yticks))
    ax.set_yticklabels(yticks)

    ax.invert_yaxis()
    ylim = ax.get_ylim()
    # Keep the lower limit, fix the upper one at -1 (= log2(0.5)).
    ax.set_ylim(ylim[0], -1)
    return yticks, ylim
# Wavelet scales 1..127 and the labels shared by the wavelet plots below.
scales = np.arange(1, 128)
title = 'Wavelet Transform (Power Spectrum) of signal'
ylabel = 'Period (years)'
xlabel = 'Time'
# Stand-alone wavelet power plot; the returned (yticks, ylim) are not needed here.
fig, ax = plt.subplots(figsize=(10, 10))
plot_wavelet(ax, time, signal, scales, xlabel=xlabel, ylabel=ylabel, title=title)
plt.show()
In [6]:
# Combined figure on a 6x6 grid: the signal on top, the wavelet power plot
# bottom-left, and the Fourier spectra (drawn vertically) bottom-right.
fig = plt.figure(figsize=(12,12))
spec = gridspec.GridSpec(ncols=6, nrows=6)
top_ax = fig.add_subplot(spec[0, 0:5])
bottom_left_ax = fig.add_subplot(spec[1:, 0:5])
bottom_right_ax = fig.add_subplot(spec[1:, 5])
plot_signal_plus_average(top_ax, time, signal, average_over = 3)
# Reuse the wavelet plot's ticks and limits so both bottom panels share a period axis.
yticks, ylim = plot_wavelet(bottom_left_ax, time, signal, scales, xlabel=xlabel, ylabel=ylabel, title=title)
plot_fft_plus_power(bottom_right_ax, time, signal, plot_direction='vertical', yticks=yticks, ylim=ylim)
bottom_right_ax.set_ylabel('Period [years]', fontsize=14)
plt.tight_layout()
plt.show()
Also have a look at: http://nicolasfauchereau.github.io/climatecode/posts/wavelet-analysis-in-python/
In [ ]: